from matplotlib import pyplot as plt
import matplotlib.patches as mpatches

import numpy as np

# ============================================
# AFFICHAGE DE LA TABLE
# ============================================
def AfficheTable(L, g):
    """
    Affiche la table de programmation dynamique pour Floyd-Warshall.
    Pour chaque k, affiche une matrice n×n des distances.
    Maximum 3 matrices par ligne.
    """
    sommets = list(g.keys())
    n = len(sommets)

    # Trouver le k max dans L
    k_max = max(key[0] for key in L.keys()) if L else 0

    # Calculer le nombre de lignes et colonnes (max 3 colonnes)
    nb_matrices = k_max + 1
    nb_cols = min(3, nb_matrices)
    nb_rows = (nb_matrices + nb_cols - 1) // nb_cols  # Arrondi supérieur

    # Créer une figure avec une grille de sous-graphiques
    fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(3.5 * nb_cols, 3.5 * nb_rows))

    # Convertir axes en tableau 2D si nécessaire
    if nb_matrices == 1:
        axes = np.array([[axes]])
    elif nb_rows == 1:
        axes = np.array([axes])
    elif nb_cols == 1:
        axes = axes.reshape(-1, 1)

    for k in range(k_max + 1):
        row = k // nb_cols
        col = k % nb_cols
        ax = axes[row, col]

        # Créer une matrice RGB pour ce k
        mat = np.zeros((n, n, 3))

        for i, v in enumerate(sommets):
            for j, w in enumerate(sommets):
                if (k, v, w) in L:
                    mat[i][j] = [0.2, 0.7, 0.3]  # Vert
                else:
                    mat[i][j] = [0.85, 0.85, 0.85]  # Gris clair

        ax.imshow(mat)

        # Afficher les valeurs dans chaque case
        for i, v in enumerate(sommets):
            for j, w in enumerate(sommets):
                if (k, v, w) in L:
                    valeur = L[(k, v, w)]
                    if valeur == float('inf'):
                        txt = '∞'
                    else:
                        txt = str(int(valeur))
                    ax.text(j, i, txt, ha='center', va='center',
                            color='white', fontsize=8, fontweight='bold')

        # Configurer les axes
        ax.set_xticks(range(n))
        ax.set_xticklabels(sommets, fontsize=8)
        ax.set_yticks(range(n))
        ax.set_yticklabels(sommets, fontsize=8)
        ax.set_xlabel('Destination (w)', fontsize=8)
        ax.set_ylabel('Origine (v)', fontsize=8)
        ax.set_title(f'k = {k}', fontsize=10)

        # Quadrillage
        ax.set_xticks(np.arange(-0.5, n, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, n, 1), minor=True)
        ax.grid(which='minor', color='black', linewidth=0.5)

    # Masquer les axes inutilisés (si nb_matrices n'est pas multiple de nb_cols)
    for k in range(nb_matrices, nb_rows * nb_cols):
        row = k // nb_cols
        col = k % nb_cols
        axes[row, col].axis('off')

    plt.suptitle('Table de programmation dynamique Floyd-Warshall', fontsize=12)
    plt.tight_layout()
    plt.show()


# Graphe représenté par un dictionnaire d'adjacence
# graphe[v] = [(w1, poids1), (w2, poids2), ...]
graphe = {
    1: [(2, 2), (3, 4)],
    2: [(3, -1), (4, 2)],
    3: [(4, 3), (5, 4)],
    4: [(5, 2)],
    5: []
}


# Variante avec cycle négatif (pour tests)
graphe_neg = {
    1: [(2, 4), (3, 2)],
    2: [(4, 3), (5, 4)],
    3: [(2, -1), (4, 2), (5,4)],
    4: [(2, -5), (5, 2)],
    5: []
}





























